Add optional reward scaling #95
Conversation
Looks good. Just needs more comments.
```
@@ -6,7 +6,7 @@ model:
train:
  seq_length: 48  # Size of LM context
  epochs: 1000  # Train for max(epochs, total_steps)
```
Thank god we finally decreased this haha
```
scores = torch.as_tensor(self.score(texts), device=samples.device)
stats["exp_score_time"] = time() - exp_score_time

if self.ref_mean is None:
```
Some comments here about what this does would be helpful :)
```
delta = xs_mean - self.mean
tot_count = self.count + xs_count

m_a = self.var * self.count
```
I hate having lots of math with no comments giving an intuitive explanation of what it's doing. Please fix.
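For intuition, the snippet implements Chan et al.'s parallel variance formula: merge a running (mean, var, count) with a new batch's statistics. A commented sketch (names follow the diff; the full class body is assumed, not copied from the PR):

```python
import torch

class RunningMoments:
    """Running mean/variance over batches via Chan et al.'s parallel
    variance formula. A sketch for intuition, not the PR's exact code."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-24  # tiny prior count avoids division by zero

    def update(self, xs: torch.Tensor):
        xs_count = xs.numel()
        xs_var, xs_mean = torch.var_mean(xs, unbiased=False)

        delta = xs_mean - self.mean        # gap between batch mean and running mean
        tot_count = self.count + xs_count  # total samples after merging

        m_a = self.var * self.count        # sum of squared deviations seen so far
        m_b = xs_var * xs_count            # sum of squared deviations in this batch
        # cross term corrects for the two groups being centered on different means
        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count

        self.mean += delta * xs_count / tot_count  # pull the running mean toward the batch
        self.var = m_2 / tot_count                 # merged (biased) variance
        self.count = tot_count
```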
Can you post the W&B link to a run where we can confirm rescaling works? Thanks.
Agreed, let's make sure to do this on all algo-changing PRs.
Run looks good to me. Add comments and we can merge!
```diff
@@ -35,6 +35,8 @@ method:
   cliprange: 0.2  # clip range
   cliprange_value: 0.2  # clip range
   vf_coef: 2.3  # value term weight
+  scale_reward: True
+  clip_reward: 10
```
Shouldn't reward generally be in the range [-1,1]?
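For context on how these two options might be consumed downstream, a hypothetical sketch (the actual call site isn't shown in this hunk): scale by a running std, then clamp outliers so a single extreme score can't dominate the PPO update.

```python
import torch

def scale_and_clip(scores: torch.Tensor, running_std: float, clip_reward: float = 10.0) -> torch.Tensor:
    # scale_reward: divide by a running estimate of the reward std
    scores = scores / (running_std + 1e-8)
    # clip_reward: clamp outliers into [-clip_reward, clip_reward]
    return torch.clamp(scores, -clip_reward, clip_reward)
```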
stats["exp_score_time"] = time() - exp_score_time | ||
|
||
if self.ref_mean is None: | ||
self.ref_mean, self.ref_std = scores.mean(), scores.std() |
Perhaps naming this ref_mean is a bit misleading? It is not the mean of the reference model but rather the mean of the training model.
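To make the pattern being discussed concrete: the first batch's statistics are cached as a baseline, and later batches can be standardized against them. A hypothetical standalone helper (the PR applies this inline in the orchestrator, with different surrounding code):

```python
import torch

class ScoreNormalizer:
    """Hypothetical helper mirroring the pattern above: the first batch
    of scores defines the baseline named ref_mean / ref_std."""

    def __init__(self):
        self.ref_mean = None
        self.ref_std = None

    def __call__(self, scores: torch.Tensor) -> torch.Tensor:
        if self.ref_mean is None:
            # first batch seen becomes the baseline distribution
            self.ref_mean, self.ref_std = scores.mean(), scores.std()
        return (scores - self.ref_mean) / (self.ref_std + 1e-8)
```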
```
def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
    """Whitens values"""
```
Out of curiosity, do we have a reference for whitening? (A blog post, an arXiv paper?)
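For what whitening does here: standardize a tensor to zero mean and unit variance (the name comes from whitening a signal's spectrum). A single-process sketch consistent with that signature; the distributed branch is omitted and the body is assumed, not the PR's exact code:

```python
import torch

def whiten(xs: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Standardize xs to zero mean and unit variance."""
    var, mean = torch.var_mean(xs)
    whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        # rescale the spread but keep the original mean
        whitened += mean
    return whitened
```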
```
)


class RunningMoments:
```
Very nice
```
class RunningMoments:
    def __init__(self):
```
If we precompute the mean and var of our initial reward distribution ahead of time, do we have a way of incorporating that?
Approved, though I'd like to know if a precomputed mean and var of baseline rewards can be used as well.
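On the precomputed-stats question: since the class only stores (mean, var, count), one way to incorporate a baseline (hypothetical usage, building on the RunningMoments sketch above) is to seed those fields directly, with the count acting as the weight of the prior against incoming batches:

```python
moments = RunningMoments()
# Seed with precomputed statistics of the initial reward distribution.
# `baseline_count` controls how strongly the prior resists new batches.
moments.mean, moments.var, moments.count = baseline_mean, baseline_var, baseline_count
```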
In the spirit of #48:
- reward scaling from https://github.com/DLR-RM/stable-baselines3/blob/d5d1a02c15cdce868c72bbc94913e66fdd2efd3a/stable_baselines3/common/vec_env/vec_normalize.py#L220
- minibatch whitening from https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/algos/pytorch/ppo/ppo.py#L80

W&B report: https://wandb.ai/sorry/public/reports/mean_reward-22-11-17-01-59-30---VmlldzoyOTg0ODc3